{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from src.anchor_generator import tile_anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "WIDTH, HEIGHT = 512, 1024\n",
    "GRID_WIDTH, GRID_HEIGHT = 16, 32  # stride 32 or scale 0 in the face detector\n",
    "# GRID_WIDTH, GRID_HEIGHT = 8, 16  # stride 64 or scale 1 in the face detector\n",
    "# GRID_WIDTH, GRID_HEIGHT = 4, 8  # stride 128 or scale 2 in the face detector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "n = 4\n",
    "anchors = tile_anchors(\n",
    "    (WIDTH, HEIGHT), GRID_HEIGHT, GRID_WIDTH,\n",
    "    scale=32, aspect_ratio=1.0, \n",
    "    anchor_stride=(1.0/GRID_HEIGHT, 1.0/GRID_WIDTH), \n",
    "    anchor_offset=(0.5/GRID_HEIGHT, 0.5/GRID_WIDTH), \n",
    "    n=n\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Session() as sess:\n",
    "    anchor_boxes = sess.run(anchors)\n",
    "\n",
    "scaler = np.array([HEIGHT, WIDTH, HEIGHT, WIDTH], dtype='float32')\n",
    "anchor_boxes = anchor_boxes*scaler  # shape [GRID_HEIGHT, GRID_WIDTH, n*n, 4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show some non clipped anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_grid_centers():\n",
    "    anchor_stride = (1.0/GRID_HEIGHT, 1.0/GRID_WIDTH)\n",
    "    anchor_offset = (0.5/GRID_HEIGHT, 0.5/GRID_WIDTH)\n",
    "\n",
    "    y_center = np.arange(GRID_HEIGHT, dtype='float32') * anchor_stride[0] + anchor_offset[0]\n",
    "    x_center = np.arange(GRID_WIDTH, dtype='float32') * anchor_stride[1] + anchor_offset[1]\n",
    "    x_center, y_center = np.meshgrid(x_center, y_center)\n",
    "    # they have shape [grid_height, grid_width]\n",
    "\n",
    "    centers = np.stack([y_center, x_center], axis=2)\n",
    "    scaler = np.array([HEIGHT, WIDTH], dtype='float32')\n",
    "    centers = centers*scaler\n",
    "    return centers\n",
    "\n",
    "\n",
    "def plot(anchor_boxes, cell_to_show):\n",
    "    fig, ax = plt.subplots(1, dpi=120, figsize=(int(8*WIDTH/HEIGHT), 8))\n",
    "\n",
    "    grid_centers = get_grid_centers()\n",
    "    for point in grid_centers.reshape(-1, 2):\n",
    "        cy, cx = point\n",
    "        ax.plot([cx], [cy], marker='.', markersize=1, color='r')\n",
    "    \n",
    "    iy, ix = cell_to_show\n",
    "    cy, cx = grid_centers[iy, ix, :]\n",
    "    ax.plot([cx], [cy], marker='.', markersize=5, color='r')\n",
    "    \n",
    "    cy, cx, h, w = [anchor_boxes[:, :, :, i] for i in range(4)]\n",
    "    centers = np.stack([cy, cx], axis=3)\n",
    "    anchor_sizes = np.stack([h, w], axis=3)\n",
    "\n",
    "    centers = centers[iy, ix, :, :]\n",
    "    anchor_sizes = anchor_sizes[iy, ix, :, :]\n",
    "    \n",
    "    to_show = [1, 4, 15]\n",
    "    for i, center, box in zip(range(len(centers)), centers, anchor_sizes):\n",
    "        \n",
    "        h, w = box\n",
    "        cy, cx = center\n",
    "        xmin, ymin = cx - 0.5*w, cy - 0.5*h\n",
    "\n",
    "        linestyle = '-' if i in to_show else '--' \n",
    "        random_color = np.random.rand(3,)\n",
    "        color = random_color if i in to_show else 'k'\n",
    "        alpha = 1.0 if i in to_show else 0.5\n",
    "        linewidth = 2.0 if i in to_show else 0.7\n",
    "\n",
    "        rect = plt.Rectangle(\n",
    "            (xmin, ymin), w, h,\n",
    "            linewidth=linewidth, edgecolor=color, \n",
    "            facecolor='none', linestyle=linestyle,\n",
    "            alpha=alpha\n",
    "        )\n",
    "        if i in to_show:\n",
    "            ax.plot([cx], [cy], marker='s', markersize=7, color=random_color)\n",
    "        ax.add_patch(rect)\n",
    "    \n",
    "    # why not ax.axis('equal')?\n",
    "    ax.set_ylim([0, HEIGHT])\n",
    "    ax.set_xlim([0, WIDTH])\n",
    "    ax.invert_yaxis()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot(anchor_boxes, cell_to_show=(2, 3))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
